import fcmpy
from pygmm import GMM
import numpy as np
import matplotlib.pyplot as plt

# Matplotlib format strings used to draw each cluster; reused cyclically when
# the model produces more clusters than there are styles.
POINT_STYLES = ['r.', 'b.', 'g.', 'y.', 'k.']


def make_samples(rng=None):
    """Build the synthetic 2-D training set and its ground-truth labels.

    Args:
        rng: object exposing ``random(shape)`` (``np.random`` module or a
            ``np.random.Generator``); defaults to the global ``np.random``.

    Returns:
        ``(x, y)`` where ``x`` is a (400, 2) array stacked from three uniform
        blobs of 200, 100 and 100 points, and ``y`` is an int array giving
        the blob index (0, 1 or 2) of each row.
    """
    rng = np.random if rng is None else rng
    # Blob 0: uniform on [5, 8) x [5, 8).
    x1 = rng.random((200, 2)) * 3 + 5
    # Blob 1: uniform on [12, 13) x [6, 8).
    x2 = np.hstack((rng.random((100, 1)) + 2, rng.random((100, 1)) * 2 - 4)) + 10
    # Blob 2: uniform on [10, 12) x [8, 10).
    x3 = np.hstack((rng.random((100, 1)) * 2, rng.random((100, 1)) * 2 - 2)) + 10
    x = np.vstack((x1, x2, x3))
    # Derive labels from the blob sizes instead of hard-coding the total length.
    y = np.repeat(np.arange(3), [len(x1), len(x2), len(x3)])
    return x, y


def assign_labels(comps, data):
    """Hard-assign each row of *data* to the component with the largest pdf.

    Args:
        comps: sequence of objects exposing ``pdf(point)``.
        data: 2-D array; each row is one point passed to ``pdf``.

    Returns:
        int array with one component index per row. Ties resolve to the
        lowest index (``np.argmax`` semantics), so exactly one label is
        always produced.

    NOTE(review): mixture weights are not applied here (same as the original
    code); confirm whether the GMM's priors should scale the pdfs.
    """
    return np.array(
        [int(np.argmax([comp.pdf(point) for comp in comps])) for point in data],
        dtype=int,
    )


def plot_labels(data, labels, title, fig_num):
    """Scatter-plot the 2-D *data* in figure *fig_num*, one style per label."""
    plt.figure(fig_num)
    plt.title(title)
    for label in np.unique(labels):
        mask = labels == label
        # Cycle styles so more than len(POINT_STYLES) clusters cannot IndexError.
        plt.plot(data[mask, 0], data[mask, 1],
                 POINT_STYLES[int(label) % len(POINT_STYLES)])


if __name__ == '__main__':
    # Training samples x and their ground-truth labels y.
    x, y = make_samples()

    # Random order: shuffle samples and labels with the same permutation.
    order = np.random.permutation(len(x))
    nx = x[order, :]
    ny = y[order]

    # Options forwarded unchanged to pygmm's ART2-based initialisation.
    art2option = {
        'a':10,
        'b':10,
        'c':0.3,
        'd':0.75,
        'theta':0.1,
        'rho':0.997,
        'output':None,
        'e':1E-8
    }

    dim = nx.shape[1]
    gmm = GMM(data=nx, dim=dim, method='ART2', ART2Option=art2option)
    gmm.em(data=nx, nsteps=100)
    testy = assign_labels(gmm.comps, nx)

    # NOTE: the CJK titles need a font with Chinese glyphs to render correctly.
    plot_labels(nx, ny, '原始分类', 1)
    plot_labels(nx, testy, 'ART2-GMM分类', 2)
    plt.show()